# Copyright 2020 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# ***THIS FILE DOES NOT BUILD WITH BAZEL***
#
# It is open sourced to enable Bazel->CMake conversion to maintain test coverage
# of our integration tests in open source while we figure out a long term plan
# for our integration testing.

# Loads the test-suite macro invoked three times below (static-shape,
# dynamic-dims and complex-number math tests). The macro is defined outside
# this file; judging by its name it presumably creates one test per
# combination of the `matrix` values it is given — confirm in the .bzl source.
load(
    "@iree//integrations/tensorflow/e2e:iree_e2e_cartesian_product_test_suite.bzl",
    "iree_e2e_cartesian_product_test_suite",
)

# Package-wide defaults: every target in this package is publicly visible,
# the `layering_check` feature is enabled, and the license kind is "notice".
package(
    default_visibility = ["//visibility:public"],
    features = ["layering_check"],
    licenses = ["notice"],  # Apache 2.0
)

# Standalone binary wrapping math_test.py so it can be run by hand with custom
# flags, outside the generated test suites below. For example, the
# COMPLEX_FUNCTIONS list further down was produced by running this target with
# `--list_functions_with_complex_tests`.
# `main` must be set explicitly because the target name does not match the
# source file name.
py_binary(
    name = "math_test_manual",
    srcs = ["math_test.py"],
    main = "math_test.py",
    python_version = "PY3",
    deps = [
        "//third_party/py/absl:app",
        "//third_party/py/absl/flags",
        "//third_party/py/iree:pylib_tf_support",
        "//third_party/py/numpy",
        "//third_party/py/tensorflow",
        "//util/debuginfo:signalsafe_addr2line_installer",
    ],
)

# Functions exercised by the static-shape (`math_tests`) and dynamic-dims
# (`math_dynamic_dims_tests`) suites below, via their `matrix["functions"]`.
# These functions were selected from all of the functions in the tf.math docs:
#   https://www.tensorflow.org/api_docs/python/tf/math
# The complex-number suite uses its own COMPLEX_FUNCTIONS list instead.
TF_MATH_FUNCTIONS = [
    "abs",
    "accumulate_n",
    "acos",
    "acosh",
    "add",
    "add_n",
    "angle",
    "argmax",
    "argmin",
    "asin",
    "asinh",
    "atan",
    "atan2",
    "atanh",
    "bessel_i0",
    "bessel_i0e",
    "bessel_i1",
    "bessel_i1e",
    "betainc",
    "bincount",
    "ceil",
    "confusion_matrix",
    "cos",
    "cosh",
    "count_nonzero",
    "cumprod",
    "cumsum",
    "cumulative_logsumexp",
    "digamma",
    "divide",
    "divide_no_nan",
    "equal",
    "erf",
    "erfc",
    "erfinv",
    "exp",
    "expm1",
    "floor",
    "floordiv",
    "floormod",
    "greater",
    "greater_equal",
    "igamma",
    "igammac",
    "imag",
    "in_top_k",
    "invert_permutation",
    "is_finite",
    "is_inf",
    "is_nan",
    "is_non_decreasing",
    "is_strictly_increasing",
    "lbeta",
    "less",
    "less_equal",
    "lgamma",
    "log",
    "log1p",
    "log_sigmoid",
    "log_softmax",
    "logical_and",
    "logical_not",
    "logical_or",
    "logical_xor",
    "maximum",
    "minimum",
    "mod",
    "multiply",
    "multiply_no_nan",
    "ndtri",
    "negative",
    "nextafter",
    "not_equal",
    "polygamma",
    "polyval",
    "pow",
    "real",
    "reciprocal",
    "reciprocal_no_nan",
    "reduce_all",
    "reduce_any",
    "reduce_euclidean_norm",
    "reduce_logsumexp",
    "reduce_max",
    "reduce_mean",
    "reduce_min",
    "reduce_prod",
    "reduce_std",
    "reduce_sum",
    "reduce_variance",
    "rint",
    "round",
    "rsqrt",
    "scalar_mul",
    "segment_max",
    "segment_mean",
    "segment_min",
    "segment_prod",
    "segment_sum",
    "sigmoid",
    "sign",
    "sin",
    "sinh",
    "sobol_sample",
    "softmax",
    "softplus",
    "softsign",
    "sqrt",
    "square",
    "squared_difference",
    "subtract",
    "tan",
    "tanh",
    "top_k",
    "truediv",
    "unsorted_segment_max",
    "unsorted_segment_mean",
    "unsorted_segment_min",
    "unsorted_segment_prod",
    "unsorted_segment_sqrt_n",
    "unsorted_segment_sum",
    "xdivy",
    "xlog1py",
    "xlogy",
    "zero_fraction",
    "zeta",
]

# ---- STATIC TESTS ----------------------------------------------- #

# Functions whose static-shape tests are listed as failing on the
# `iree_llvmaot` target backend; passed to `math_tests` through
# `failing_configurations`. How the macro treats these entries (e.g. expected
# failure vs. disabled) is defined in the .bzl — TODO confirm. Remove entries
# as they start passing.
# keep sorted
LLVM_FAILING = [
    "acos",
    "argmax",
    "argmin",
    "asin",
    "atan",
    "atan2",
    "bessel_i0",
    "bessel_i0e",
    "bessel_i1",
    "bessel_i1e",
    "betainc",
    "bincount",
    "confusion_matrix",
    "count_nonzero",
    "cumprod",
    "cumsum",
    "cumulative_logsumexp",
    "divide",  # Failing for integer inputs because iree doesn't output 'f64'.
    "erfinv",
    "igamma",
    "igammac",
    "in_top_k",
    "invert_permutation",
    "is_non_decreasing",
    "is_strictly_increasing",
    "ndtri",
    "nextafter",
    "reduce_all",
    "reduce_any",
    "reduce_euclidean_norm",
    "reduce_prod",
    "segment_max",
    "segment_mean",
    "segment_min",
    "segment_prod",
    "segment_sum",
    "sobol_sample",
    "softsign",
    "top_k",
    "unsorted_segment_max",
    "unsorted_segment_mean",
    "unsorted_segment_min",
    "unsorted_segment_prod",
    "unsorted_segment_sqrt_n",
    "unsorted_segment_sum",
]

# Functions whose static-shape tests are listed as failing on the
# `iree_vulkan` target backend. `math_tests` concatenates this list with
# TURING_VULKAN_FAILING. Entries marked TODO(hanchung) are meant to be
# re-enabled "after integration", per their comments.
# keep sorted
VULKAN_FAILING = [
    "acos",
    "argmax",
    "argmin",
    "asin",
    "asinh",
    "atan",
    "atan2",
    "bessel_i0",
    "bessel_i0e",
    "bessel_i1",
    "bessel_i1e",
    "betainc",
    "bincount",
    "confusion_matrix",
    "count_nonzero",
    "cumprod",
    "cumsum",
    "cumulative_logsumexp",
    "divide",  # Failing for integer inputs because iree doesn't output 'f64'.
    "equal",  # TODO(hanchung): Enable the test after integration.
    "erfinv",
    "greater",  # TODO(hanchung): Enable the test after integration.
    "greater_equal",  # TODO(hanchung): Enable the test after integration.
    "igamma",
    "igammac",
    "in_top_k",
    "invert_permutation",
    "is_inf",  # TODO(hanchung): Enable the test after integration.
    "is_nan",  # TODO(hanchung): Enable the test after integration.
    "is_non_decreasing",
    "is_strictly_increasing",
    "less",  # TODO(hanchung): Enable the test after integration.
    "less_equal",  # TODO(hanchung): Enable the test after integration.
    "logical_and",  # TODO(hanchung): Enable the test after integration.
    "logical_not",
    "logical_or",  # TODO(hanchung): Enable the test after integration.
    "logical_xor",
    "ndtri",
    "nextafter",
    "not_equal",  # TODO(hanchung): Enable the test after integration.
    "polygamma",
    "pow",
    "reduce_all",
    "reduce_any",
    "reduce_euclidean_norm",
    "reduce_prod",
    "segment_max",
    "segment_mean",
    "segment_min",
    "segment_prod",
    "segment_sum",
    "sign",
    "sobol_sample",
    "softsign",
    "top_k",
    "unsorted_segment_max",
    "unsorted_segment_mean",
    "unsorted_segment_min",
    "unsorted_segment_prod",
    "unsorted_segment_sqrt_n",
    "unsorted_segment_sum",
    "zeta",
]

# Functions whose static-shape Vulkan tests fail only on Nvidia (Turing) GPUs.
# `math_tests` simply concatenates this list onto VULKAN_FAILING. As a result,
# these functions are marked failing for every `iree_vulkan` run, not only on
# Nvidia hardware.
# TODO(#5400): Failing only on Nvidia GPU.
# keep sorted
TURING_VULKAN_FAILING = [
    "digamma",
    "is_finite",
    "mod",
]

# Static-shape test suite: every function in TF_MATH_FUNCTIONS, run on the
# `iree_llvmaot` and `iree_vulkan` target backends with `tf` as the reference
# backend. Both `dynamic_dims` and `test_complex` are disabled.
iree_e2e_cartesian_product_test_suite(
    name = "math_tests",
    failing_configurations = [
        {
            # Failing on llvm.
            "functions": LLVM_FAILING,
            "target_backends": "iree_llvmaot",
        },
        {
            # Failing on vulkan. The Nvidia-only (Turing) failures are folded
            # in as well, so they apply to every Vulkan run.
            "functions": VULKAN_FAILING + TURING_VULKAN_FAILING,
            "target_backends": "iree_vulkan",
        },
    ],
    matrix = {
        "src": "math_test.py",
        "reference_backend": "tf",
        "functions": TF_MATH_FUNCTIONS,
        "dynamic_dims": False,
        "test_complex": False,
        "target_backends": [
            "iree_llvmaot",
            "iree_vulkan",
        ],
    },
    # Same deps as the math_test_manual binary above.
    deps = [
        "//third_party/py/absl:app",
        "//third_party/py/absl/flags",
        "//third_party/py/iree:pylib_tf_support",
        "//third_party/py/numpy",
        "//third_party/py/tensorflow",
        "//util/debuginfo:signalsafe_addr2line_installer",
    ],
)

# ---- DYNAMIC TESTS ---------------------------------------------- #

# Functions whose dynamic-dims tests are listed as failing on `iree_llvmaot`;
# used by `math_dynamic_dims_tests`.
# NOTE(review): a few entries of LLVM_FAILING ("reduce_all", "reduce_any",
# "softsign") do not appear here. They are therefore not marked failing with
# dynamic dims even though they are for static shapes — worth re-verifying.
# keep sorted
LLVM_FAILING_DYNAMIC = [
    "acos",
    "angle",
    "argmax",
    "argmin",
    "asin",
    "atan",
    "atan2",
    "bessel_i0",
    "bessel_i0e",
    "bessel_i1",
    "bessel_i1e",
    "betainc",
    "bincount",
    "confusion_matrix",
    "count_nonzero",
    "cumprod",
    "cumsum",
    "cumulative_logsumexp",
    "divide",
    "erfinv",
    "expm1",
    "igamma",
    "igammac",
    "imag",
    "in_top_k",
    "invert_permutation",
    "is_non_decreasing",
    "is_strictly_increasing",
    "log1p",
    "log_sigmoid",
    "log_softmax",
    "ndtri",
    "nextafter",
    "reduce_euclidean_norm",
    "reduce_logsumexp",
    "reduce_mean",
    "reduce_prod",
    "reduce_std",
    "reduce_variance",
    "segment_max",
    "segment_mean",
    "segment_min",
    "segment_prod",
    "segment_sum",
    "sinh",
    "sobol_sample",
    "softplus",
    "top_k",
    "unsorted_segment_max",
    "unsorted_segment_mean",
    "unsorted_segment_min",
    "unsorted_segment_prod",
    "unsorted_segment_sqrt_n",
    "unsorted_segment_sum",
    "xlog1py",
    "xlogy",
    "zero_fraction",
]

# Functions whose dynamic-dims tests are listed as failing on `iree_vulkan`;
# used by `math_dynamic_dims_tests`. This currently covers every entry of
# TF_MATH_FUNCTIONS except "real".
# keep sorted
VULKAN_FAILING_DYNAMIC = [
    "abs",
    "accumulate_n",
    "acos",
    "acosh",
    "add",
    "add_n",
    "angle",
    "argmax",
    "argmin",
    "asin",
    "asinh",
    "atan",
    "atan2",
    "atanh",
    "bessel_i0",
    "bessel_i0e",
    "bessel_i1",
    "bessel_i1e",
    "betainc",
    "bincount",
    "ceil",
    "confusion_matrix",
    "cos",
    "cosh",
    "count_nonzero",
    "cumprod",
    "cumsum",
    "cumulative_logsumexp",
    "digamma",
    "divide",
    "divide_no_nan",
    "equal",
    "erf",
    "erfc",
    "erfinv",
    "exp",
    "expm1",
    "floor",
    "floordiv",
    "floormod",
    "greater",
    "greater_equal",
    "igamma",
    "igammac",
    "imag",
    "in_top_k",
    "invert_permutation",
    "is_finite",
    "is_inf",
    "is_nan",
    "is_non_decreasing",
    "is_strictly_increasing",
    "lbeta",
    "less",
    "less_equal",
    "lgamma",
    "log",
    "log1p",
    "log_sigmoid",
    "log_softmax",
    "logical_and",
    "logical_not",
    "logical_or",
    "logical_xor",
    "maximum",
    "minimum",
    "mod",
    "multiply",
    "multiply_no_nan",
    "ndtri",
    "negative",
    "nextafter",
    "not_equal",
    "polygamma",
    "polyval",
    "pow",
    "reciprocal",
    "reciprocal_no_nan",
    "reduce_all",
    "reduce_any",
    "reduce_euclidean_norm",
    "reduce_logsumexp",
    "reduce_max",
    "reduce_mean",
    "reduce_min",
    "reduce_prod",
    "reduce_std",
    "reduce_sum",
    "reduce_variance",
    "rint",
    "round",
    "rsqrt",
    "scalar_mul",
    "segment_max",
    "segment_mean",
    "segment_min",
    "segment_prod",
    "segment_sum",
    "sigmoid",
    "sign",
    "sin",
    "sinh",
    "sobol_sample",
    "softmax",
    "softplus",
    "softsign",
    "sqrt",
    "square",
    "squared_difference",
    "subtract",
    "tan",
    "tanh",
    "top_k",
    "truediv",
    "unsorted_segment_max",
    "unsorted_segment_mean",
    "unsorted_segment_min",
    "unsorted_segment_prod",
    "unsorted_segment_sqrt_n",
    "unsorted_segment_sum",
    "xdivy",
    "xlog1py",
    "xlogy",
    "zero_fraction",
    "zeta",
]

# Dynamic-dims test suite: the same TF_MATH_FUNCTIONS and backends as
# `math_tests`, but with `dynamic_dims` enabled. It uses the *_DYNAMIC failing
# lists.
iree_e2e_cartesian_product_test_suite(
    name = "math_dynamic_dims_tests",
    failing_configurations = [
        {
            # Failing on llvm.
            "functions": LLVM_FAILING_DYNAMIC,
            "target_backends": "iree_llvmaot",
        },
        {
            # Failing on vulkan.
            "functions": VULKAN_FAILING_DYNAMIC,
            "target_backends": "iree_vulkan",
        },
    ],
    matrix = {
        "src": "math_test.py",
        "reference_backend": "tf",
        "functions": TF_MATH_FUNCTIONS,
        "dynamic_dims": True,
        "test_complex": False,
        "target_backends": [
            "iree_llvmaot",
            "iree_vulkan",
        ],
    },
    # Same deps as the math_test_manual binary above.
    deps = [
        "//third_party/py/absl:app",
        "//third_party/py/absl/flags",
        "//third_party/py/iree:pylib_tf_support",
        "//third_party/py/numpy",
        "//third_party/py/tensorflow",
        "//util/debuginfo:signalsafe_addr2line_installer",
    ],
)

# ---- COMPLEX TESTS ---------------------------------------------- #

# Functions run with complex inputs by `math_complex_tests`. Unlike
# TF_MATH_FUNCTIONS, this list comes from math_test.py itself. It includes
# entries ("conj", "l2_normalize") that are not in TF_MATH_FUNCTIONS.
# This list was generated by running:
#   bazel run integrations/tensorflow/e2e/math:math_test_manual -- --list_functions_with_complex_tests
# keep sorted
COMPLEX_FUNCTIONS = [
    "abs",
    "add",
    "angle",
    "asinh",
    "atanh",
    "conj",
    "cos",
    "cosh",
    "count_nonzero",
    "cumprod",
    "cumsum",
    "divide",
    "divide_no_nan",
    "exp",
    "expm1",
    "imag",
    "l2_normalize",
    "log",
    "log1p",
    "multiply",
    "multiply_no_nan",
    "negative",
    "pow",
    "real",
    "reciprocal",
    "reciprocal_no_nan",
    "reduce_euclidean_norm",
    "reduce_std",
    "reduce_variance",
    "rsqrt",
    "sigmoid",
    "sign",
    "sin",
    "sinh",
    "sqrt",
    "square",
    "squared_difference",
    "subtract",
    "tan",
    "tanh",
    "truediv",
    "xdivy",
    "xlog1py",
    "xlogy",
    "zero_fraction",
]

# Functions whose complex-input tests are listed as failing on `iree_llvmaot`;
# used by `math_complex_tests`.
# keep sorted
LLVM_FAILING_COMPLEX = [
    "angle",
    "asinh",
    "atanh",
    "cos",
    "cosh",
    "count_nonzero",
    "cumprod",
    "cumsum",
    "divide",
    "divide_no_nan",
    "expm1",
    "log",
    "log1p",
    "multiply_no_nan",
    "negative",
    "pow",
    "reciprocal",
    "reciprocal_no_nan",
    "reduce_euclidean_norm",
    "reduce_std",
    "reduce_variance",
    "rsqrt",
    "sigmoid",
    "sign",
    "sin",
    "sinh",
    "sqrt",
    "tan",
    "tanh",
    "xdivy",
    "xlog1py",
    "xlogy",
    "zero_fraction",
]

# Functions whose complex-input tests are listed as failing on `iree_vulkan`;
# used by `math_complex_tests`. Currently identical to LLVM_FAILING_COMPLEX.
# It is kept as a separate list so each backend can be updated independently.
# keep sorted
VULKAN_FAILING_COMPLEX = [
    "angle",
    "asinh",
    "atanh",
    "cos",
    "cosh",
    "count_nonzero",
    "cumprod",
    "cumsum",
    "divide",
    "divide_no_nan",
    "expm1",
    "log",
    "log1p",
    "multiply_no_nan",
    "negative",
    "pow",
    "reciprocal",
    "reciprocal_no_nan",
    "reduce_euclidean_norm",
    "reduce_std",
    "reduce_variance",
    "rsqrt",
    "sigmoid",
    "sign",
    "sin",
    "sinh",
    "sqrt",
    "tan",
    "tanh",
    "xdivy",
    "xlog1py",
    "xlogy",
    "zero_fraction",
]

# Complex-number test suite: COMPLEX_FUNCTIONS with `test_complex` enabled and
# static shapes (`dynamic_dims` disabled). It runs on the same backends and
# reference as the other suites.
iree_e2e_cartesian_product_test_suite(
    name = "math_complex_tests",
    failing_configurations = [
        {
            # Failing on llvm.
            "functions": LLVM_FAILING_COMPLEX,
            "target_backends": "iree_llvmaot",
        },
        {
            # Failing on vulkan.
            "functions": VULKAN_FAILING_COMPLEX,
            "target_backends": "iree_vulkan",
        },
    ],
    matrix = {
        "src": "math_test.py",
        "reference_backend": "tf",
        "functions": COMPLEX_FUNCTIONS,
        "dynamic_dims": False,
        "test_complex": True,
        "target_backends": [
            "iree_llvmaot",
            "iree_vulkan",
        ],
    },
    # Same deps as the math_test_manual binary above.
    deps = [
        "//third_party/py/absl:app",
        "//third_party/py/absl/flags",
        "//third_party/py/iree:pylib_tf_support",
        "//third_party/py/numpy",
        "//third_party/py/tensorflow",
        "//util/debuginfo:signalsafe_addr2line_installer",
    ],
)
